use core::{
    alloc::Layout,
    marker::PhantomData,
    ptr::{slice_from_raw_parts, slice_from_raw_parts_mut, NonNull},
};

use log::trace;

use crate::{
    align::*,
    err::{PagingError, PagingResult},
    iter::TableIter,
    page_table_entry::Pte,
    Access, MapConfig, PTEArch, PTEGeneric, PTEInfo,
};

/// A lightweight reference to a single page table.
///
/// `addr` is the table's physical address; its entries are accessed at the
/// virtual address `addr + Access::va_offset()`. `walk` records the table's
/// level (roots are created at `P::level()`, and level 1 is the table whose
/// entries map single pages) together with the table geometry.
///
/// The handle is `Copy` and therefore has no `Drop`: it never frees the table
/// by itself. Use `release` to free a table tree.
#[derive(Clone, Copy)]
pub struct PageTableRef<'a, P: PTEArch> {
    // Physical address of the page holding this table.
    addr: usize,
    // Level and size information used to index into this table.
    walk: PageWalk,
    // Ties the handle to lifetime `'a` and architecture `P` without storing either.
    _marker: PhantomData<&'a P>,
}

impl<'a, P: PTEArch> PageTableRef<'a, P> {
    /// Allocates a zero-filled root table (level `P::level()`) and returns a
    /// reference to it.
    ///
    /// # Errors
    ///
    /// Returns [`PagingError::NoMemory`] if `access` cannot provide a page.
    pub fn create_empty(access: &mut impl Access) -> PagingResult<Self> {
        Self::new_with_level(P::level(), access)
    }

    /// Allocates a zero-filled table for `level` and returns a reference to it.
    ///
    /// `level` counts from 1 (the table whose entries map single pages) up to
    /// `P::level()` for the root.
    ///
    /// # Panics
    ///
    /// Panics if `level` is 0.
    ///
    /// # Errors
    ///
    /// Returns [`PagingError::NoMemory`] if the allocation fails.
    pub fn new_with_level(level: usize, access: &mut impl Access) -> PagingResult<Self> {
        assert!(level > 0);
        let addr = unsafe { Self::alloc_table(access)? };
        Ok(Self::from_addr(addr, level))
    }

    /// Wraps the table at physical address `addr` as a table of `level`.
    ///
    /// Nothing is read or allocated here; `addr` is trusted to point at a table.
    pub fn from_addr(addr: usize, level: usize) -> Self {
        Self {
            addr,
            walk: PageWalk::new(P::page_size(), level),
            _marker: PhantomData,
        }
    }

    /// Level of this table (1 = the table whose entries map single pages).
    pub fn level(&self) -> usize {
        self.walk.level
    }

    /// Physical address of this table.
    pub fn paddr(&self) -> usize {
        self.addr
    }

    /// Maps `size` bytes starting at `config.vaddr` to `config.paddr`, without a
    /// per-entry callback.
    ///
    /// Equivalent to [`Self::map_region_with_handle`] with `on_page_mapped` set
    /// to `None`; `allow_block` is forwarded as its `allow_huge` argument.
    ///
    /// # Errors
    ///
    /// Same as [`Self::map_region_with_handle`]: [`PagingError::NotAligned`] for
    /// a misaligned `vaddr`, `paddr` or `size`, or an allocation error.
    ///
    /// # Safety
    /// User must ensure that the physical address is valid.
    pub unsafe fn map_region(
        &mut self,
        config: MapConfig,
        size: usize,
        allow_block: bool,
        access: &mut impl Access,
    ) -> PagingResult<()> {
        let no_callback: Option<&fn(*const u8)> = None;
        self.map_region_with_handle(config, size, allow_block, access, no_callback)
    }

    /// Map a contiguous virtual memory region to a contiguous physical memory
    /// region using the attributes in `cfg.setting`.
    ///
    /// The virtual and physical regions start at `cfg.vaddr` and `cfg.paddr`
    /// respectively and are `size` bytes long. Both start addresses and `size`
    /// must be multiples of the page size; otherwise
    /// [`Err(PagingError::NotAligned)`] is returned before anything is mapped.
    ///
    /// When `allow_huge` is true, each step maps the largest block size that
    /// both current addresses are aligned to and that still fits in the
    /// remaining length. Otherwise every step maps a single page.
    ///
    /// If `on_page_mapped` is given, it is called once per written entry (page
    /// or block) with the starting virtual address of that entry.
    ///
    /// # Errors
    ///
    /// - [`PagingError::NotAligned`] for a misaligned `vaddr`, `paddr` or `size`.
    /// - Any error from creating intermediate tables (e.g. [`PagingError::NoMemory`]).
    ///   Entries written before the failure are left in place.
    ///
    /// [`Err(PagingError::NotAligned)`]: PagingError::NotAligned
    ///
    /// # Safety
    /// User must ensure that the physical address is valid, and that `self`
    /// refers to a live table tree reachable through `access`.
    pub unsafe fn map_region_with_handle(
        &mut self,
        cfg: MapConfig,
        size: usize,
        allow_huge: bool,
        access: &mut impl Access,
        on_page_mapped: Option<&impl Fn(*const u8)>,
    ) -> PagingResult {
        let vaddr = cfg.vaddr;
        let paddr = cfg.paddr;
        let page_size = P::page_size();

        if !vaddr.is_aligned_to(page_size) {
            return Err(PagingError::NotAligned("vaddr"));
        }

        if !paddr.is_aligned_to(page_size) {
            return Err(PagingError::NotAligned("paddr"));
        }

        // A length that is not a page multiple would make `remaining -= map_size`
        // below underflow (panic in debug, runaway mapping loop in release).
        if !size.is_aligned_to(page_size) {
            return Err(PagingError::NotAligned("size"));
        }

        trace!(
            "map_region: [{:#x}, {:#x}) -> [{:#x}, {:#x}) {:?}",
            vaddr as usize,
            vaddr as usize + size,
            paddr,
            paddr + size,
            cfg.setting,
        );

        let mut map_cfg = cfg;
        let mut remaining = size;

        while remaining > 0 {
            let level = if allow_huge {
                self.walk
                    .detect_align_level(map_cfg.vaddr, remaining)
                    .min(self.walk.detect_align_level(map_cfg.paddr as _, remaining))
            } else {
                1
            };
            self.get_entry_or_create(&map_cfg, level, access)?;

            let map_size = self.walk.copy_with_level(level).level_entry_size();

            if let Some(f) = on_page_mapped {
                // Report the entry that was just written, not the region start.
                f(map_cfg.vaddr);
            }
            // Plain address arithmetic: the vaddr is not a pointer into a Rust
            // allocation, so `wrapping_add` (not `add`) is the sound operation.
            map_cfg.vaddr = map_cfg.vaddr.wrapping_add(map_size);
            map_cfg.paddr += map_size;
            remaining -= map_size;
        }
        Ok(())
    }

    /// Views this table's entries as a shared slice of [`Pte`]s.
    ///
    /// The table is read at virtual address `self.addr + access.va_offset()`.
    /// NOTE(review): the returned lifetime `'a` is not tied to `self` or
    /// `access`, so nothing prevents it from overlapping a mutable view
    /// obtained from `as_slice_mut`.
    pub fn as_slice(&self, access: &impl Access) -> &'a [Pte<P>] {
        let first_entry = (self.addr + access.va_offset()) as *const Pte<P>;
        // SAFETY: assumes `addr` is a live table page visible at
        // `addr + va_offset`, holding exactly `table_size` entries.
        unsafe { &*slice_from_raw_parts(first_entry, self.walk.table_size) }
    }

    /// Returns the next-level table referenced by entry `idx`, first allocating
    /// and linking a fresh zero-filled table if the entry is not valid.
    ///
    /// A newly linked entry is written with `map_cfg.setting` as its attributes.
    ///
    /// # Errors
    ///
    /// - [`PagingError::NotAligned`] if entry `idx` is a valid block (huge-page)
    ///   mapping. It has no sub-table to descend into, and treating the block's
    ///   physical address as a table would overwrite the mapped memory.
    /// - Errors from allocating the new table (e.g. [`PagingError::NoMemory`]).
    ///
    /// # Safety
    ///
    /// `self` must refer to a live table reachable through `access`; its entries
    /// are read and written through raw pointers.
    unsafe fn sub_table_or_create(
        &mut self,
        idx: usize,
        map_cfg: &MapConfig,
        access: &mut impl Access,
    ) -> PagingResult<PageTableRef<'a, P>> {
        let mut pte = self.get_pte(idx, access);
        let sub_level = self.level() - 1;

        if pte.valid() {
            if pte.is_block {
                // Mirrors `next_table`, which also refuses to treat a block entry
                // as a sub-table. NOTE(review): a dedicated "already mapped" error
                // variant would describe this case better than `NotAligned`.
                return Err(PagingError::NotAligned("vaddr"));
            }
            return Ok(Self::from_addr(pte.paddr, sub_level));
        }

        let table = Self::new_with_level(sub_level, access)?;
        pte.is_valid = true;
        pte.paddr = table.addr;
        pte.is_block = false;
        // NOTE(review): the intermediate entry inherits the attributes of the
        // first mapping that created it. Where table-entry permissions restrict
        // everything below them, later mappings through this table may be
        // limited by it — confirm for each `PTEArch`.
        pte.setting = map_cfg.setting;

        self.as_slice_mut(access)[idx] = P::new_pte(pte);
        Ok(table)
    }

    /// Writes the leaf (or block) entry for `map_cfg` into the table at `level`,
    /// creating missing intermediate tables on the way down from `self`.
    ///
    /// `level > 1` is passed as the second argument of `PTEGeneric::new`
    /// (presumably the block flag). Any entry already in the target slot is
    /// overwritten.
    ///
    /// # Errors
    ///
    /// Returns [`PagingError::NotAligned`] if `level` is outside
    /// `1..=self.level()`, and propagates errors from `sub_table_or_create`.
    ///
    /// # Safety
    ///
    /// `self` must refer to a live table tree reachable through `access`.
    unsafe fn get_entry_or_create(
        &mut self,
        map_cfg: &MapConfig,
        level: usize,
        access: &mut impl Access,
    ) -> PagingResult<()> {
        // Reject unreachable target levels up front. Walking towards them would
        // allocate intermediate tables and then trip `new_with_level(0)`'s assert.
        if level == 0 || level > self.level() {
            return Err(PagingError::NotAligned("vaddr"));
        }

        let mut table = *self;
        while table.level() > 0 {
            let idx = table.index_of_table(map_cfg.vaddr);
            if table.level() == level {
                // NOTE(review): if this slot held a sub-table pointer, that
                // sub-table becomes unreachable and is never freed.
                table.as_slice_mut(access)[idx] =
                    P::new_pte(PTEGeneric::new(map_cfg.paddr, level > 1, map_cfg.setting));
                return Ok(());
            }
            table = table.sub_table_or_create(idx, map_cfg, access)?;
        }
        // Defensive: not reachable once `level` has been validated above.
        Err(PagingError::NotAligned("vaddr"))
    }

    /// Frees this table and every sub-table reachable from it.
    ///
    /// Afterwards the memory behind this handle (and behind any copy of it) has
    /// been returned to `access` and must not be used again.
    pub fn release(&mut self, access: &mut impl Access) {
        // Free the children first. An early `None` from the walk is ignored;
        // the root table is freed regardless.
        let _ = self._release(0usize as _, access);

        let page_size = P::page_size();
        // SAFETY: the root page was obtained from `access` with this same
        // page-size/page-aligned layout (see `alloc_table`).
        unsafe {
            access.dealloc(
                self.addr.to_virt(access),
                Layout::from_size_align_unchecked(page_size, page_size),
            );
        }
    }

    /// Recursively frees every sub-table below this table (not this table).
    ///
    /// `start_vaddr` is the first virtual address covered by this table; each
    /// child receives the address its entry covers. Level-1 tables and block
    /// entries reference no sub-tables and are skipped.
    fn _release(&mut self, start_vaddr: *const u8, access: &mut impl Access) -> Option<()> {
        let entries = self.as_slice(access);

        if self.level() == 1 {
            return Some(());
        }

        let base = start_vaddr as usize;
        for (idx, entry) in entries.iter().enumerate() {
            let child_vaddr: *const u8 = (base + idx * self.entry_size()) as _;
            let pte = entry.read();

            // Only valid, non-block entries point at a sub-table.
            if !pte.valid() || pte.is_block {
                continue;
            }

            let mut child = self.next_table(idx, access)?;
            child._release(child_vaddr, access)?;

            // SAFETY: sub-tables are allocated by `alloc_table` with this layout.
            unsafe {
                access.dealloc(
                    pte.paddr.to_virt(access),
                    Layout::from_size_align_unchecked(P::page_size(), P::page_size()),
                );
            }
        }
        Some(())
    }

    /// Returns the sub-table referenced by entry `idx`, or `None` if that entry
    /// is a block mapping or is not valid.
    fn next_table(&self, idx: usize, access: &impl Access) -> Option<Self> {
        let pte = self.get_pte(idx, access);
        (!pte.is_block && pte.valid()).then(|| Self::from_addr(pte.paddr, self.level() - 1))
    }

    /// Index of the entry in this table that covers `vaddr`.
    fn index_of_table(&self, vaddr: *const u8) -> usize {
        self.walk.index_of_table(vaddr)
    }

    /// Bytes of virtual address space covered by one entry of this table
    /// (e.g. 4 KiB at level 1 and 2 MiB at level 2 with 4 KiB pages).
    pub fn entry_size(&self) -> usize {
        self.walk.level_entry_size()
    }

    /// Number of entries in this table.
    pub fn table_size(&self) -> usize {
        self.walk.table_size
    }

    /// Views this table's entries as raw `usize` words for in-place updates.
    ///
    /// NOTE(review): like `as_slice`, the returned `'a` borrow is not tied to
    /// `self`, so overlapping mutable views are not prevented by the compiler.
    fn as_slice_mut(&mut self, access: &impl Access) -> &'a mut [usize] {
        let first_word = (self.addr + access.va_offset()) as *mut usize;
        // SAFETY: assumes `addr` is a live table page visible at
        // `addr + va_offset`, holding exactly `table_size` words.
        unsafe { &mut *slice_from_raw_parts_mut(first_word, self.walk.table_size) }
    }

    /// Reads entry `idx` of this table as a [`PTEGeneric`].
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.table_size()`.
    fn get_pte(&self, idx: usize, access: &impl Access) -> PTEGeneric {
        self.as_slice(access)[idx].read()
    }

    /// Allocates one zero-filled, page-aligned page for a table and returns its
    /// physical address (the allocated virtual address minus `va_offset()`).
    ///
    /// # Errors
    ///
    /// Returns [`PagingError::NoMemory`] if `access.alloc` returns `None`.
    ///
    /// # Safety
    ///
    /// `P::page_size()` must be a non-zero power of two, as required by
    /// `Layout::from_size_align_unchecked`.
    unsafe fn alloc_table(access: &mut impl Access) -> PagingResult<usize> {
        let page_size = P::page_size();
        let layout = Layout::from_size_align_unchecked(page_size, page_size);
        let page = access.alloc(layout).ok_or(PagingError::NoMemory)?;
        // New tables start zero-filled.
        page.write_bytes(0, page_size);
        Ok(page.as_ptr() as usize - access.va_offset())
    }

    /// Returns a [`TableIter`] over this table tree.
    ///
    /// `0` is passed as the iterator's starting address (presumably the virtual
    /// base of the walk; see `TableIter::new`).
    pub fn iter_all<A: Access>(&self, access: &'a A) -> impl Iterator<Item = PTEInfo> + 'a {
        TableIter::new(0 as _, *self, access)
    }
}

/// Returns `floor(log2(value))`, i.e. the index of the highest set bit.
///
/// For the power-of-two page and table sizes used by `PageWalk` this is the
/// exact shift amount (`log2(4096) == 12`, `log2(512) == 9`).
///
/// # Panics
///
/// Panics if `value` is zero.
const fn log2(value: usize) -> usize {
    assert!(value > 0, "Value must be positive and non-zero");
    // `ilog2` is the standard floor-log2; it replaces the hand-rolled shift loop
    // and makes the former 512/4096 special cases unnecessary.
    value.ilog2() as usize
}

/// Conversion from a physical address to a virtual pointer through an `Access`.
pub trait PVConvert {
    /// Returns the virtual address corresponding to this physical address,
    /// cast to `NonNull<T>`.
    fn to_virt<T>(&self, access: &impl Access) -> NonNull<T>;
}

impl PVConvert for usize {
    /// Translates `self` by adding `access.va_offset()`.
    ///
    /// # Panics
    ///
    /// Panics if the translated address is null.
    fn to_virt<T>(&self, access: &impl Access) -> NonNull<T> {
        let vaddr = (self + access.va_offset()) as *mut u8;
        // This is a safe fn, so a null result must not reach
        // `NonNull::new_unchecked` (that would be undefined behavior).
        NonNull::new(vaddr)
            .expect("physical address translated to a null virtual address")
            .cast()
    }
}

/// Geometry of one level of a page-table walk.
///
/// Values are derived from the page size in `PageWalk::new`; the `_pow` fields
/// hold log2 of the corresponding sizes and are used as shift amounts.
#[derive(Debug, Clone, Copy)]
pub struct PageWalk {
    // Table level this walk describes; 1 is the level whose entries map single pages.
    level: usize,
    // Number of entries per table (`page_size / size_of::<usize>()`).
    table_size: usize,
    // log2(table_size): virtual-address bits consumed by each level.
    table_size_pow: usize,
    // log2(page_size): shift of the level-1 entry size.
    page_size_pow: usize,
}

impl PageWalk {
    /// Describes a table at `level` for the given `page_size`.
    ///
    /// A table fills exactly one page of `usize`-sized entries, so it holds
    /// `page_size / size_of::<usize>()` entries. Both sizes are expected to be
    /// powers of two, since they are stored as shift amounts.
    fn new(page_size: usize, level: usize) -> Self {
        let table_size = page_size / size_of::<usize>();
        let table_size_pow = log2(table_size);
        let page_size_pow = log2(page_size);

        Self {
            table_size,
            table_size_pow,
            page_size_pow,
            level,
        }
    }

    /// Returns the same geometry re-targeted at `level`.
    fn copy_with_level(&self, level: usize) -> Self {
        Self { level, ..*self }
    }

    /// log2 of the bytes covered by one entry at this level: `page_size_pow`
    /// at level 1, plus `table_size_pow` for every level above it.
    ///
    /// `self.level` must be at least 1, otherwise `level - 1` underflows.
    fn level_entry_size_shift(&self) -> usize {
        self.page_size_pow + (self.level - 1) * self.table_size_pow
    }

    /// Index of the entry covering `vaddr` in a table at this level.
    fn index_of_table(&self, vaddr: *const u8) -> usize {
        (vaddr as usize >> self.level_entry_size_shift()) & (self.table_size - 1)
    }

    /// Bytes of address space covered by one entry at this level.
    fn level_entry_size(&self) -> usize {
        1 << self.level_entry_size_shift()
    }

    /// Picks the highest level below `self.level` whose entry size `vaddr` is
    /// aligned to and that fits within `size`; falls back to 1 (single pages).
    ///
    /// Only levels `1..self.level` are candidates. Level 0 has no entry size
    /// (`level_entry_size_shift` would underflow), so a misaligned `vaddr` or a
    /// `size` smaller than one page now yields 1 instead of overflowing.
    fn detect_align_level(&self, vaddr: *const u8, size: usize) -> usize {
        (1..self.level)
            .rev()
            .find(|&level| {
                let level_size = self.copy_with_level(level).level_entry_size();
                vaddr as usize % level_size == 0 && size >= level_size
            })
            .unwrap_or(1)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    const MB: usize = 1024 * 1024;
    const GB: usize = 1024 * MB;

    #[test]
    fn test_log2() {
        assert_eq!(log2(512), 9);
        assert_eq!(log2(4096), 12);
    }

    /// `log2` is floor(log2) for every positive input, not only the common
    /// page/table sizes.
    #[test]
    fn test_log2_general() {
        assert_eq!(log2(1), 0);
        assert_eq!(log2(2), 1);
        assert_eq!(log2(1000), 9);
        assert_eq!(log2(1024), 10);
        assert_eq!(log2(16 * 1024), 14);
    }

    #[test]
    #[should_panic]
    fn test_log2_zero_panics() {
        let _ = log2(0);
    }

    #[test]
    fn test_level_entry_memory_size() {
        assert_eq!(PageWalk::new(4096, 1).level_entry_size(), 4096);
        assert_eq!(PageWalk::new(4096, 2).level_entry_size(), 2 * MB);
        assert_eq!(PageWalk::new(4096, 3).level_entry_size(), GB);
        assert_eq!(PageWalk::new(4096, 4).level_entry_size(), 512 * GB);
    }

    #[test]
    fn test_idx_of_table() {
        let w = PageWalk::new(4096, 1);
        assert_eq!(w.index_of_table(0 as _), 0);
        assert_eq!(w.index_of_table(0x1000 as _), 1);
        assert_eq!(w.index_of_table(0x2000 as _), 2);

        let w = PageWalk::new(4096, 2);
        assert_eq!(w.index_of_table(0 as _), 0);
        assert_eq!(w.index_of_table((2 * MB) as _), 1);

        let w = PageWalk::new(4096, 3);
        assert_eq!(w.index_of_table(GB as _), 1);

        let w = PageWalk::new(4096, 4);
        assert_eq!(w.index_of_table((512 * GB) as _), 1);
    }

    /// Indices wrap modulo the number of entries in a table.
    #[test]
    fn test_idx_of_table_wraps() {
        let w = PageWalk::new(4096, 1);
        assert_eq!(w.index_of_table(0x1f_f000 as _), 511);
        assert_eq!(w.index_of_table((2 * MB) as _), 0);
        assert_eq!(w.index_of_table((2 * MB + 0x1000) as _), 1);
    }

    #[test]
    fn test_copy_with_level() {
        let w = PageWalk::new(4096, 4);
        let l2 = w.copy_with_level(2);
        assert_eq!(l2.level, 2);
        assert_eq!(l2.level_entry_size(), 2 * MB);
        // The source walk is left untouched.
        assert_eq!(w.level, 4);
    }

    #[test]
    fn test_detect_align() {
        let s = 4 * GB;

        let w = PageWalk::new(0x1000, 4);
        assert_eq!(w.detect_align_level(0x1000 as _, s), 1);

        assert_eq!(w.detect_align_level((0x1000 * 512) as _, s), 2);

        assert_eq!(w.detect_align_level((0x1000 * 512 * 512) as _, s), 3);

        assert_eq!(w.detect_align_level((2 * GB) as _, s), 3);
    }

    /// The remaining size caps the chosen level, and the walk's own (root)
    /// level is never selected.
    #[test]
    fn test_detect_align_limits() {
        let w = PageWalk::new(0x1000, 4);
        assert_eq!(w.detect_align_level((2 * MB) as _, 0x1000), 1);
        assert_eq!(w.detect_align_level(GB as _, 2 * MB), 2);
        assert_eq!(w.detect_align_level((512 * GB) as _, 1024 * GB), 3);
    }
}
